
Conversation

joecummings (Member) commented on Sep 19, 2025

What does this PR do?

  1. The PR defaults the dtype of the Trainer and ReferenceModel to bf16 (a rough sketch of the effect is shown below this list).
  2. I also slipped in a change that lets training proceed for as long as needed, toggled by the steps param in the trainer. Don't worry about it :)
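
At a high level, the default behaves like the plain-PyTorch sketch below. The names here (build_model, the layer sizes) are hypothetical stand-ins for illustration, not the actual Trainer/ReferenceModel code:

```python
import torch
import torch.nn as nn

def build_model(dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    """Hypothetical stand-in for the Trainer / ReferenceModel setup."""
    model = nn.TransformerEncoderLayer(d_model=1024, nhead=16)
    # Casting parameters to bf16 halves their memory footprint
    # (2 bytes/param vs. 4 bytes/param for fp32).
    return model.to(dtype=dtype)

trainer_model = build_model()    # bf16 by default
reference_model = build_model()  # bf16 by default
```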

How do we know this works?

The primary way we know this works is by examining the memory taken up when running the models: I confirmed it is roughly halved by looking at nvtop logs. Luckily, we also have another easy way to confirm it works, because when you compute RMS norm with an input in bfloat16 and a weight in fp32, PyTorch emits this warning:

[0] /home/jrcummings/.fbpkg_conda_envs/forge-a7401c7/lib/python3.10/site-packages/torch/nn/functional.py:2920: UserWarning: Mismatch dtype between input and weight: input dtype = c10::BFloat16, weight dtype = float, Cannot dispatch to fused implementation. (Triggered internally at /mnt/code/pytorch/aten/src/ATen/native/layer_norm.cpp:344.)
[0]   return torch.rms_norm(input, normalized_shape, weight, eps)

Once this change is merged, the warning goes away.
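
For the curious, here is a standalone repro of that warning (my own snippet, not Forge code; it assumes a PyTorch version that exposes torch.nn.functional.rms_norm):

```python
import torch
import torch.nn.functional as F

hidden = 4096
x = torch.randn(2, hidden, dtype=torch.bfloat16)  # activations in bf16
w = torch.ones(hidden, dtype=torch.float32)       # norm weight still in fp32

# Mismatched dtypes: PyTorch warns it cannot dispatch to the fused
# implementation and falls back to a slower path.
y = F.rms_norm(x, (hidden,), weight=w, eps=1e-6)

# With the weight in bf16 as well (what defaulting the model to bf16 gives us),
# the warning disappears.
y = F.rms_norm(x, (hidden,), weight=w.to(torch.bfloat16), eps=1e-6)
```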

FAQs

  1. Does this work for single device? What an astute question: NO. The distributed APIs handle the conversion to a lower dtype, so if you don't use them, everything stays in fp32. This is annoying, for sure, but not blocking. Keep tracking this issue for more information.
  2. What about training stability? Fair play. While it is common practice to post-train in bf16, people have raised concerns that it performs worse than fp32. See here. Experiments before and after this change don't raise any red flags, but I would consider this part of the ongoing "correctness" work to ensure it doesn't cause problems (a toy illustration of the precision concern follows this list). cc @Ritesh1905
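
As a toy illustration of where that concern comes from (my own example, not from this PR's experiments): bf16 keeps fp32's exponent range but has only about 8 bits of mantissa, so a sufficiently small parameter update can be rounded away entirely:

```python
import torch

lr_times_grad = 1e-3  # a small parameter update

w_fp32 = torch.tensor(1.0, dtype=torch.float32)
w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)

print(w_fp32 - lr_times_grad)                                      # tensor(0.9990)
print(w_bf16 - torch.tensor(lr_times_grad, dtype=torch.bfloat16))  # tensor(1.) -- update rounded away
```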

The meta-cla bot added the CLA Signed label on Sep 19, 2025
joecummings marked this pull request as ready for review on September 19, 2025 at 21:00
joecummings merged commit 605f85f into meta-pytorch:main on Sep 19, 2025
5 checks passed
joecummings deleted the train-in-bf16 branch on September 19, 2025 at 21:05
allenwang28 (Contributor) commented:

> Does this work for single device? What an astute question: NO.

Isn't the 1.7B example still using single device though?
